{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一.变分EM算法\n",
    "在介绍LDA的变分EM实现之前，首先我们要弄懂什么是变分EM，变分推断我们在之前提过，EM算法也在之前提过，可是将它俩凑一起似乎就不认识了呢....，这里还请大家回去看一下15章的这俩小结[《变分推断的原理推导》](https://nbviewer.jupyter.org/github/zhulei227/ML_Notes/blob/master/notebooks/15_01_VI_%E5%8F%98%E5%88%86%E6%8E%A8%E6%96%AD%E7%9A%84%E5%8E%9F%E7%90%86%E6%8E%A8%E5%AF%BC.ipynb)以及[《变分推断与EM的关系》](https://nbviewer.jupyter.org/github/zhulei227/ML_Notes/blob/master/notebooks/15_02_VI_%E5%8F%98%E5%88%86%E6%8E%A8%E6%96%AD%E4%B8%8EEM%E7%9A%84%E5%85%B3%E7%B3%BB.ipynb) 看完之后，也许就能猜出变分EM是要怎么做了，下面我再做一个简单的说明\n",
    "\n",
    "![avatar](./source/15_EM中三者间的关系.png)"
   ]
  },
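  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Concretely, the decomposition shown in the figure holds for any distribution $q(z)$:   \n",
    "\n",
    "$$\n",
    "\\ln p(X\\mid\\theta)=\\underbrace{E_q\\left[\\ln\\frac{p(X,z\\mid\\theta)}{q(z)}\\right]}_{L(q,\\theta)}+\\underbrace{E_q\\left[\\ln\\frac{q(z)}{p(z\\mid X,\\theta)}\\right]}_{KL(q||p)}\n",
    "$$"
   ]
  },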
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如上图，我们知道带参数的对数似然函数$ln\\ p(X\\mid\\theta)$，可以拆解为一个证据下界ELBO函数$L(q,\\theta)$和一个KL距离$KL(q||p)$，这里的$q$即是我们的变分分布，为了使它更加简单通常对其做一个平均场假设，即各个隐变量（组）之间是独立的：   \n",
    "\n",
    "$$\n",
    "q(z)=q(z_1)q(z_2)\\cdots q(z_n)\n",
    "$$  \n",
    "\n",
    "而$p$则是复杂的后验概率分布：   \n",
    "\n",
    "$$\n",
    "p(z)=p(z\\mid X,\\theta)\n",
    "$$  \n",
    "\n",
    "好的，在此基础上，我们来看看变分推断和EM分别要做怎么样的一件事：   \n",
    "\n",
    "#### 变分推断\n",
    "变分推断要做的事情是让简单的变分分布去近似复杂的后验分布：   \n",
    "\n",
    "$$\n",
    "q(z)\\rightarrow p(z)\n",
    "$$  \n",
    "\n",
    "它并不关心$\\theta$，将其视作一个常数处理，通过最大化ELBO函数来使得$KL(q||p)$最小化，从而使$q$与$p$近似，即它要做的是如下的优化问题：   \n",
    "\n",
    "$$\n",
    "q^*=arg\\max_{q}L(q,\\theta)\n",
    "$$  \n",
    "\n",
    "#### EM\n",
    "而EM算法的初衷是要通过优化$\\theta$使得对数似然函数极大化，它令：   \n",
    "\n",
    "$$\n",
    "q(z)=p(z\\mid X,\\theta^{old})\n",
    "$$  \n",
    "\n",
    "这时$KL(q||p)=0$，所以有：   \n",
    "\n",
    "$$\n",
    "L(q,\\theta)=ln\\ p(X\\mid \\theta)\n",
    "$$  \n",
    "\n",
    "这时，对ELBO函数极大化等价于对对数似然函数极大化，从而得到最优解：   \n",
    "\n",
    "$$\n",
    "\\theta=arg\\max_{\\theta}L(q\\mid\\theta),q(z)=p(z\\mid X,\\theta^{old})\n",
    "$$  \n",
    "\n",
    "显然，使用EM算法的前提是后验概率分布$p(z\\mid X,\\theta)$的形式比较方便求解，如果它很复杂呢？那变分EM就诞生了...\n",
    "\n",
    "#### 变分EM\n",
    "\n",
    "变分EM不改EM的初衷，即要使得对数似然函数$ln\\ p(X\\mid \\theta)$极大化，同时对于$p(z\\mid X,\\theta)$利用一个简单的变分分布$q(z)$去近似，所以变分EM算法的优化变量包括两个：$q$和$\\theta$   \n",
    "\n",
    "$$\n",
    "q^*,\\theta^*=arg\\max_{q,\\theta}L(q,\\theta)\n",
    "$$   \n",
    "而优化过程通常采用坐标轮换法，即：   \n",
    "\n",
    "（1）E步：固定$\\theta$，求$L(q,\\theta)$对$q$的最大化；   \n",
    "（2）M步：固定$q$，求$L(q,\\theta)$对$\\theta$的最大化；  \n",
    "\n",
    "如果对变分EM没问题了，接下来就要考虑利用变分EM去求LDA模型的什么分布勒？回想一下上一节的Gibbs采样，我们去求相同的分布不就可以了，但LDA的变分EM求解还做了进一步的简化"
   ]
  },
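  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal structural sketch of this coordinate ascent is given below; the `e_step` and `m_step` callables are hypothetical placeholders for the two argmax subproblems, not part of any library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def variational_em(e_step, m_step, theta0, epochs=100):\n",
    "    \"\"\"Generic variational EM skeleton: alternately maximize the ELBO L(q, theta).\"\"\"\n",
    "    theta = theta0\n",
    "    q = None\n",
    "    for _ in range(epochs):\n",
    "        q = e_step(theta)    # E-step: q* = argmax_q L(q, theta) with theta fixed\n",
    "        theta = m_step(q)    # M-step: theta* = argmax_theta L(q, theta) with q fixed\n",
    "    return q, theta"
   ]
  },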
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二.ELBO推导\n",
    "变分EM对LDA模型做了如下简化\n",
    "![avatar](./source/16_变分EM的LDA模型.png)\n",
    "省略了超参数$\\beta$，其中$\\alpha,\\varphi$为模型参数，$\\theta,z$是隐变量，$w$是可观测变量，为了简便，一次只考虑一个文本，记作$w=(w_1,w_2,...,w_N)$，对应的主题序列$z=(z_1,z_2,...,z_N)$，对应的话题分布为$\\theta$，所以其联合概率分布可以表示为：   \n",
    "\n",
    "$$\n",
    "p(\\theta,z,w\\mid\\alpha,\\varphi)=p(\\theta\\mid\\alpha)\\prod_{n=1}^Np(z_n\\mid\\theta)p(w_n\\mid z_n,\\varphi)\n",
    "$$  \n",
    "\n",
    "所以，我们需要去近似后验概率$p(\\theta,z\\mid w,\\alpha,\\varphi)$，可以定义变分分布为：   \n",
    "\n",
    "$$\n",
    "q(\\theta,z\\mid\\gamma,\\eta)=q(\\theta\\mid\\gamma)\\prod_{n=1}^Nq(z_n\\mid\\eta_n)\n",
    "$$  \n",
    "\n",
    "其中，$\\gamma=(\\gamma_1,\\gamma_2,...,\\gamma_K)$是狄利克雷分布参数，$\\eta=(\\eta_1,\\eta_2,...,\\eta_n)$是多项分布参数，变量$\\theta$和$z$的各个分量都是条件独立的，它的盘子图如下：   \n",
    "![avatar](./source/16_变分EM的变分分布.png)  \n",
    "\n",
    "所以，其证据下界ELBO可以写作：   \n",
    "\n",
    "$$\n",
    "L(\\gamma,\\eta,\\alpha,\\varphi)=E_q[ln\\ p(\\theta,z,w\\mid\\alpha,\\varphi)]-E_q[ln\\ q(\\theta,z\\mid\\gamma,\\eta)]\n",
    "$$  \n",
    "\n",
    "其中数学期望是对分布$q(\\theta,z\\mid\\gamma,\\eta)$定义的，为了方便简写为$E_q[\\cdot]$，$\\gamma$和$\\eta$是变分分布的参数，$\\alpha$和$\\varphi$是LDA模型的参数"
   ]
  },
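  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of this ELBO formula, here is a Monte-Carlo sketch for a single toy document; the values of `alpha`, `phi`, `gamma` and `eta` below are made up for illustration, not fitted. Since the ELBO lower-bounds the evidence, the estimate should not exceed $\\ln p(w\\mid\\alpha,\\varphi)$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import dirichlet\n",
    "\n",
    "np.random.seed(0)\n",
    "K, V, N = 3, 5, 4                              # topics, vocabulary size, document length\n",
    "w = np.array([0, 2, 2, 4])                     # toy document as word ids\n",
    "alpha = np.ones(K)                             # LDA parameter alpha\n",
    "phi = np.random.dirichlet(np.ones(V), size=K)  # LDA parameter phi, rows sum to 1\n",
    "gamma = alpha + N / K                          # variational Dirichlet parameter\n",
    "eta = np.full((N, K), 1.0 / K)                 # variational multinomial parameters\n",
    "\n",
    "def elbo_mc(samples=5000):\n",
    "    total = 0.0\n",
    "    for _ in range(samples):\n",
    "        theta = np.random.dirichlet(gamma)                               # theta ~ q(theta|gamma)\n",
    "        z = np.array([np.random.choice(K, p=eta[n]) for n in range(N)])  # z_n ~ q(z_n|eta_n)\n",
    "        log_p = dirichlet.logpdf(theta, alpha)                           # ln p(theta|alpha)\n",
    "        log_p += np.sum(np.log(theta[z]))                                # sum_n ln p(z_n|theta)\n",
    "        log_p += np.sum(np.log(phi[z, w]))                               # sum_n ln p(w_n|z_n,phi)\n",
    "        log_q = dirichlet.logpdf(theta, gamma)                           # ln q(theta|gamma)\n",
    "        log_q += np.sum(np.log(eta[np.arange(N), z]))                    # sum_n ln q(z_n|eta_n)\n",
    "        total += log_p - log_q\n",
    "    return total / samples\n",
    "\n",
    "print(elbo_mc())"
   ]
  },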
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三.参数求解\n",
    "\n",
    "所以接下来，就是固定$\\gamma,\\eta$，求$L(\\gamma,\\eta,\\alpha,\\varphi)$对$\\alpha,\\varphi$的极大化，然后固定$\\alpha,\\varphi$，求$L(\\gamma,\\eta,\\alpha,\\varphi)$对$\\gamma,\\eta$的极大化，持续下去直到收敛...推导过程的公式不想码了，哈哈哈，自己看《统计学习方法》吧，下面就直接写参数求解的算法了，并对必要的符号做说明\n",
    "\n",
    "#### 算法一：对变分参数$\\gamma,\\eta$估计\n",
    ">初始化：对所有的$k$和$n$，$\\eta_{nk}^{(0)}=1/K$（$k$表示主题，$n$表示当前文本的第$n$个位置，$K$表示总主题数）  \n",
    "\n",
    ">初始化：对所有的$k$，$\\gamma_k=\\alpha_k+N/K$（$k$表示主题，$K$表示总主题数，$N$表示当前文本的总字数）  \n",
    "\n",
    ">重复\n",
    ">>对$n=1:N$\n",
    ">>>对$k=1:K$\n",
    ">>>>$\\eta_{nk}^{(t+1)}=\\varphi_{kv}exp\\left[\\Psi(\\gamma_k^{(t)})-\\Psi(\\sum_{l=1}^K\\gamma_l^{(t)})\\right]$\n",
    "\n",
    ">>>规范化$\\eta_{nk}^{(t+1)}$使其和为1\n",
    "\n",
    ">>$\\gamma^{(t+1)}=\\alpha+\\sum_{n=1}^N\\eta_n^{(t+1)}$\n",
    "\n",
    ">直到收敛\n",
    "\n",
    "这里，$\\Psi(\\cdot)$为digamma函数，即：   \n",
    "\n",
    "$$\n",
    "\\Psi(x)=\\frac{d\\ ln\\ \\Gamma(x)}{dx}\n",
    "$$\n",
    "$\\Psi(\\cdot)$可以使用`scipy.special.digamma`直接求解，哈哈哈~\n",
    "\n",
    "#### 算法二：对LDA参数$\\alpha,\\varphi$估计\n",
    "基于上面的$\\gamma,\\eta$可以写出$\\alpha,\\varphi$的计算公式,...省略了推导过程....\n",
    "\n",
    "$$\n",
    "\\varphi_{kv}=\\sum_{m=1}^M\\sum_{n=1}^{N_m}\\eta_{mnk}w_{mn}^v\n",
    "$$  \n",
    "\n",
    "其中，$\\eta_{mnk}$表示第$m$个文本的第$n$个单词属于第$k$个话题的概率，$w_{mn}^v$在第$m$个文本的第$n$个单词是单词集合的第$v$个单词时取值为1，否则为0，而$\\alpha$的更新为  \n",
    "\n",
    "$$\n",
    "\\alpha_{new}=\\alpha_{old}-H(\\alpha_{old})^{-1}g(\\alpha_{old})\n",
    "$$  \n",
    "\n",
    "其中，$g(\\cdot)$表示其梯度，计算公式为：   \n",
    "\n",
    "$$\n",
    "\\frac{\\partial L}{\\partial \\alpha_k}=M\\left[\\Psi(\\sum_{l=1}^K\\alpha_l)-\\Psi(\\alpha_k) \\right]-\\sum_{m=1}^M\\left[\\Psi(\\gamma_{mk})-\\Psi(\\sum_{l=1}^K\\gamma_{ml})\\right]\n",
    "$$  \n",
    "\n",
    "而$H$表示Hessian矩阵，计算公式如下：   \n",
    "\n",
    "$$\n",
    "\\frac{\\partial^2 L}{\\partial\\alpha_k\\partial\\alpha_l}=M\\left[\\Psi'\\left(\\sum_{l=1}^K\\alpha_l\\right)-\\delta(k,l)\\Psi'(\\alpha_k)\\right]\n",
    "$$  \n",
    "\n",
    "其中，$\\delta(k,l)$是delta函数\n",
    "\n",
    "所以，对**算法一**和**算法二**交替迭代，直到收敛即是我们想要的结果，另外对于$\\alpha$的更新，为了方便，笔者将Hessian矩阵的逆$H^{-1}$这一部分替换为学习率进行计算，即将二阶的牛顿法替换为一阶的梯度下降法（其实是能力有限，不知道$\\psi'$和$delta$如何计算...哈哈哈）"
   ]
  },
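  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick numerical check that `scipy.special.digamma` really is the derivative of $\\ln\\Gamma(x)$, via a central finite difference; `scipy.special.polygamma(1, x)` likewise gives $\\Psi'(x)$, which the exact Hessian would need:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.special import digamma, gammaln, polygamma\n",
    "\n",
    "x = np.array([0.5, 1.0, 2.5, 10.0])\n",
    "h = 1e-6\n",
    "finite_diff = (gammaln(x + h) - gammaln(x - h)) / (2 * h)  # numerical d ln Gamma(x) / dx\n",
    "print(np.allclose(finite_diff, digamma(x), atol=1e-4))     # expect: True\n",
    "print(polygamma(1, x))                                     # Psi'(x), the trigamma function"
   ]
  },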
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "隐狄利克雷分布的代码实现，包括Gibbs采样和变分EM算法，代码封装在ml_models.latent_dirichlet_allocation\n",
    "\"\"\"\n",
    "import numpy as np\n",
    "from scipy.special import digamma\n",
    "\n",
    "\n",
    "class LDA(object):\n",
    "    def __init__(self, alpha=None, beta=None, K=10, tol=1e-3, epochs=100, method=\"gibbs\", lr=1e-5):\n",
    "        \"\"\"\n",
    "        :param alpha: 主题分布的共轭狄利克雷分布的超参数\n",
    "        :param beta: 单词分布的共轭狄利克雷分布的超参数\n",
    "        :param K: 主题数量\n",
    "        :param tol:容忍度，允许tol的隐变量差异\n",
    "        :param epochs:最大迭代次数\n",
    "        :param method:优化方法，默认gibbs,另外还有变分EM,vi_em\n",
    "        :param lr:学习率,对vi_em生效\n",
    "        \"\"\"\n",
    "        self.alpha = alpha\n",
    "        self.beta = beta\n",
    "        self.K = K\n",
    "        self.tol = tol\n",
    "        self.epochs = epochs\n",
    "        self.method = method\n",
    "        self.lr = lr\n",
    "        self.phi = None  # 主题-单词矩阵\n",
    "\n",
    "    def _init_params(self, W):\n",
    "        \"\"\"\n",
    "        初始化参数\n",
    "        :param W:\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        M = len(W)  # 文本数\n",
    "        V = 0  # 词典大小\n",
    "        I = 0  # 单词总数\n",
    "        for w in W:\n",
    "            V = max(V, max(w))\n",
    "            I += len(w)\n",
    "        V += 1  # 包括0\n",
    "        # 文本话题计数\n",
    "        N_M_K = np.zeros(shape=(M, self.K))\n",
    "        N_M = np.zeros(M)\n",
    "        # 话题单词计数\n",
    "        N_K_V = np.zeros(shape=(self.K, V))\n",
    "        N_K = np.zeros(self.K)\n",
    "        # 初始化隐状态,计数矩阵\n",
    "        Z = []  # 隐状态，与W一一对应\n",
    "        p = [1 / self.K] * self.K\n",
    "        hidden_status = list(range(self.K))\n",
    "        for m, w in enumerate(W):\n",
    "            z = np.random.choice(hidden_status, len(w), replace=True, p=p).tolist()\n",
    "            Z.append(z)\n",
    "            for n, k in enumerate(z):\n",
    "                v = w[n]\n",
    "                N_M_K[m][k] += 1\n",
    "                N_M[m] += 1\n",
    "                N_K_V[k][v] += 1\n",
    "                N_K[k] += 1\n",
    "        # 初始化alpha和beta\n",
    "        if self.alpha is None:\n",
    "            self.alpha = np.ones(self.K)\n",
    "        if self.beta is None:\n",
    "            self.beta = np.ones(V)\n",
    "        return Z, N_M_K, N_M, N_K_V, N_K, M, V, I, hidden_status\n",
    "\n",
    "    def _fit_gibbs(self, W):\n",
    "        \"\"\"\n",
    "        :param W: 文本集合[[...],[...]]\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        Z, N_M_K, N_M, N_K_V, N_K, M, V, I, hidden_status = self._init_params(W)\n",
    "        for _ in range(self.epochs):\n",
    "            error_num = 0\n",
    "            for m, w in enumerate(W):\n",
    "                z = Z[m]\n",
    "                for n, topic in enumerate(z):\n",
    "                    word = w[n]\n",
    "                    N_M_K[m][topic] -= 1\n",
    "                    N_M[m] -= 1\n",
    "                    N_K_V[topic][word] -= 1\n",
    "                    N_K[topic] -= 1\n",
    "                    # 采样一个新k\n",
    "                    p = []  # 更新多项分布\n",
    "                    for k_ in range(self.K):\n",
    "                        p_ = (N_K_V[k_][word] + self.beta[word]) * (N_M_K[m][k_] + self.alpha[topic]) / (\n",
    "                            (N_K[k_] + np.sum(self.beta)) * (N_M[m] + np.sum(self.alpha)))\n",
    "                        p.append(p_)\n",
    "                    ps = np.sum(p)\n",
    "                    p = [p_ / ps for p_ in p]\n",
    "                    topic_new = np.random.choice(hidden_status, 1, p=p)[0]\n",
    "                    if topic_new != topic:\n",
    "                        error_num += 1\n",
    "                    Z[m][n] = topic_new\n",
    "                    N_M_K[m][topic_new] += 1\n",
    "                    N_M[m] += 1\n",
    "                    N_K_V[topic_new][word] += 1\n",
    "                    N_K[topic_new] += 1\n",
    "            if error_num / I < self.tol:\n",
    "                break\n",
    "\n",
    "        # 计算参数phi\n",
    "        self.phi = N_K_V / np.sum(N_K_V, axis=1, keepdims=True)\n",
    "\n",
    "    def _fit_vi_em(self, W):\n",
    "        \"\"\"\n",
    "        分为两部分，迭代计算：\n",
    "        （1）给定lda参数，更新变分参数\n",
    "        （2）给定变分参数，更新lda参数\n",
    "        :param W:\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        V = 0  # 词典大小\n",
    "        for w in W:\n",
    "            V = max(V, max(w))\n",
    "        V += 1\n",
    "        M = len(W)\n",
    "\n",
    "        # 给定lda参数，更新变分参数\n",
    "        def update_vi_params(alpha, phi):\n",
    "            eta = []\n",
    "            gamma = []\n",
    "            for w in W:\n",
    "                N = len(w)\n",
    "                eta_old = np.ones(shape=(N, self.K)) * (1 / self.K)\n",
    "                gamma_old = alpha + N / self.K\n",
    "                eta_new = np.zeros_like(eta_old)\n",
    "                for _ in range(self.epochs):\n",
    "                    for n in range(0, N):\n",
    "                        for k in range(0, self.K):\n",
    "                            eta_new[n, k] = phi[k, w[n]] * np.exp(digamma(gamma_old[k]) - digamma(np.sum(gamma_old)))\n",
    "                    eta_new = eta_new / np.sum(eta_new, axis=1, keepdims=True)\n",
    "                    gamma_new = alpha + np.sum(eta_new, axis=0)\n",
    "                    if (np.sum(np.abs(gamma_new - gamma_old)) + np.sum(np.abs((eta_new - eta_old)))) / (\n",
    "                                (N + 1) * self.K) < self.tol:\n",
    "                        break\n",
    "                    else:\n",
    "                        eta_old = eta_new.copy()\n",
    "                        gamma_old = gamma_new.copy()\n",
    "                eta.append(eta_new)\n",
    "                gamma.append(gamma_new)\n",
    "            return eta, gamma\n",
    "\n",
    "        # 给定变分参数，更新lda参数\n",
    "        def update_lda_params(eta, gamma, alpha_old):\n",
    "            # 更新phi\n",
    "            phi = np.zeros(shape=(self.K, V))\n",
    "            for m, w in enumerate(W):\n",
    "                for n, word in enumerate(w):\n",
    "                    for k in range(0, self.K):\n",
    "                        for v in range(0, V):\n",
    "                            phi[k, v] += eta[m][n, k] * (word == v)\n",
    "            # 更新alpha\n",
    "            d_alpha = []\n",
    "            for k, alpha_ in enumerate(alpha_old):\n",
    "                tmp = M * (digamma(np.sum(alpha_old)) - digamma(alpha_))\n",
    "                for m in range(M):\n",
    "                    tmp -= (digamma(gamma[m][k]) - digamma(np.sum(gamma[m])))\n",
    "                d_alpha.append(tmp)\n",
    "            alpha_new = alpha_old - self.lr * np.asarray(d_alpha)\n",
    "            alpha_new = np.where(alpha_new < 0.0, 0.0, alpha_new)\n",
    "            alpha_new = alpha_new / (1e-9 + np.sum(alpha_new)) * self.K\n",
    "            phi = phi / (np.sum(phi, axis=1, keepdims=True) + 1e-9)\n",
    "            return alpha_new, phi\n",
    "\n",
    "        # 初始化alpha和phi\n",
    "        alpha_old = np.random.random(self.K)\n",
    "        phi_old = np.random.random(size=(self.K, V))\n",
    "        phi_old = phi_old / np.sum(phi_old, axis=1, keepdims=True)\n",
    "        for _ in range(self.epochs):\n",
    "            eta, gamma = update_vi_params(alpha_old, phi_old)\n",
    "            alpha_new, phi_new = update_lda_params(eta, gamma, alpha_old)\n",
    "            if (np.sum(np.abs(alpha_new - alpha_old)) + np.sum(np.abs((phi_new - phi_old)))) / (\n",
    "                        (V + 1) * self.K) < self.tol:\n",
    "                break\n",
    "            else:\n",
    "                alpha_old = alpha_new.copy()\n",
    "                phi_old = phi_new.copy()\n",
    "        self.phi = phi_new\n",
    "\n",
    "    def fit(self, W):\n",
    "        if self.method == \"gibbs\":\n",
    "            self._fit_gibbs(W)\n",
    "        else:\n",
    "            self._fit_vi_em(W)\n",
    "\n",
    "    def transform(self, W):\n",
    "        rst = []\n",
    "        for w in W:\n",
    "            tmp = np.zeros(shape=self.K)\n",
    "            for v in w:\n",
    "                try:\n",
    "                    v_ = self.phi[:, v]\n",
    "                except:\n",
    "                    v_ = np.zeros(shape=self.K)\n",
    "                tmp += v_\n",
    "            if np.sum(tmp) > 0:\n",
    "                tmp = tmp / np.sum(tmp)\n",
    "            rst.append(tmp)\n",
    "        return np.asarray(rst)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 六.测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs=[\n",
    "    [\"有\",\"微信\",\"红包\",\"的\",\"软件\"],\n",
    "    [\"微信\",\"支付\",\"不行\",\"的\"],\n",
    "    [\"我们\",\"需要\",\"稳定的\",\"微信\",\"支付\",\"接口\"],\n",
    "    [\"申请\",\"公众号\",\"认证\"],\n",
    "    [\"这个\",\"还有\",\"几天\",\"放\",\"垃圾\",\"流量\"],\n",
    "    [\"可以\",\"提供\",\"聚合\",\"支付\",\"系统\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0, 1, 2, 3, 4],\n",
       " [1, 5, 6, 3],\n",
       " [7, 8, 9, 1, 5, 10],\n",
       " [11, 12, 13],\n",
       " [14, 15, 16, 17, 18, 19],\n",
       " [20, 21, 22, 5, 23]]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word2id={}\n",
    "idx=0\n",
    "W=[]\n",
    "for doc in docs:\n",
    "    tmp=[]\n",
    "    for word in doc:\n",
    "        if word in word2id:\n",
    "            tmp.append(word2id[word])\n",
    "        else:\n",
    "            word2id[word]=idx\n",
    "            idx+=1\n",
    "            tmp.append(word2id[word])\n",
    "    W.append(tmp)\n",
    "W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "lda=LDA(epochs=200,method=\"vi_em\")\n",
    "lda.fit(W)\n",
    "trans=lda.transform(W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.15263755445005087"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第二句和第三句应该比较近似，因为它们都含有“微信”，“支付”\n",
    "trans[1].dot(trans[2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0021851131099540695"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#而第二句和第四句的相似度显然不如第二句和第三句\n",
    "trans[1].dot(trans[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.00023384778414162343"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#当然第二句和第五句的差距也有些大\n",
    "trans[1].dot(trans[4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.15459772783826697"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#而第一句和第二句都含有“微信”，所以相似度会比第四、五句高，但这里比第三句高...\n",
    "trans[1].dot(trans[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从结果来看还基本能接受的，还有训练速度会比gibbs快不少，另外代码效率还有不少的优化空间~~~"
   ]
  },
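  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check of that speed claim, one can time both methods on this toy corpus (an informal sketch; the numbers depend on the machine, the random seed, and the convergence tolerance):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "for method in [\"gibbs\", \"vi_em\"]:\n",
    "    start = time.time()\n",
    "    LDA(epochs=200, method=method).fit(W)\n",
    "    print(method, \"fit in %.3fs\" % (time.time() - start))"
   ]
  },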
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
